"""Create figures for Simulation A."""
from __future__ import annotations
import argparse, math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def fig_v_vs_d(df, outp: Path):
    """Scatter visibility V against distinguishability D and save to *outp*.

    Every row of *df* (columns ``D`` and ``V``) becomes one point. A dashed
    reference curve V = sqrt(1 - D^2) over D in [0, 1] is overlaid, i.e. the
    locus where V^2 + D^2 = 1. The figure is closed after saving.
    """
    fig, ax = plt.subplots()
    ax.scatter(df["D"], df["V"], s=20)

    # Reference curve V^2 + D^2 = 1, sampled densely enough to look smooth.
    d_grid = np.linspace(0, 1, 400)
    ax.plot(d_grid, np.sqrt(1 - d_grid * d_grid), linestyle="--")

    ax.set_xlabel("D (distinguishability)")
    ax.set_ylabel("V (visibility)")
    ax.set_title("Present‑Act V2.1 — Visibility vs Distinguishability")

    fig.savefig(outp, dpi=160, bbox_inches="tight")
    plt.close(fig)  # release the figure; this module creates several in a row

def fig_vd_vs_m(df, outp: Path):
    """Plot the per-``m`` medians of V and D against overlap m; save to *outp*.

    Rows of *df* are grouped by column ``m`` (sorted ascending by groupby)
    and the median of ``V`` and ``D`` is taken within each group. Each
    quantity is drawn as a marked line with its own legend entry.
    """
    per_m = df.groupby("m", as_index=False)[["V", "D"]].median()

    fig, ax = plt.subplots()
    # Labels live on the lines themselves so the legend cannot mis-pair them.
    for column, legend_label in (("V", "V(m)"), ("D", "D(m)")):
        ax.plot(per_m["m"], per_m[column], marker="o", label=legend_label)

    ax.set_xlabel("overlap m")
    ax.set_ylabel("value")
    ax.set_title("Present‑Act V2.1 — V(m) and D(m) (median across seeds)")
    ax.legend()

    fig.savefig(outp, dpi=160, bbox_inches="tight")
    plt.close(fig)

def fig_ablation(df_main, df_ab, outp: Path, *, m_values=range(9), tol=0.05):
    """Plot median V^2 + D^2 versus overlap m for the baseline and ablations.

    A shaded band marks ``1 - tol`` .. ``1 + tol``. The baseline median
    curve and one median curve per ablation are drawn on top of it.

    Args:
        df_main: baseline table with columns ``m`` and ``V2_plus_D2``.
        df_ab: ablation table with columns ``ablation``, ``m`` and
            ``V2_plus_D2``; ``None`` or empty plots the baseline only
            (and no legend).
        outp: destination image path.
        m_values: overlap values forming the x axis (default 0..8). All
            medians are aligned to this grid; m values with no data show
            as gaps instead of breaking the plot.
        tol: half-width of the shaded band around 1 (default 0.05).
    """
    xs = list(m_values)
    fig = plt.figure()
    ax = fig.gca()

    # Unlabeled on purpose: it must not take a legend slot.
    ax.fill_between(xs, 1 - tol, 1 + tol, alpha=0.15)

    # Reindex onto the x grid so a missing/extra m cannot cause a
    # length mismatch between x and y.
    base_med = df_main.groupby("m")["V2_plus_D2"].median().reindex(xs)
    ax.plot(xs, base_med.to_numpy(), marker="o", label="baseline")

    if df_ab is not None and len(df_ab):
        for name, grp in df_ab.groupby("ablation"):
            mvals = grp.groupby("m")["V2_plus_D2"].median().reindex(xs)
            ax.plot(xs, mvals.to_numpy(), marker="o", label=str(name))
        # Labels are attached to each line above. Passing a bare label list
        # would pair labels with artists in creation order, and the band
        # (created first) would then steal the "baseline" label.
        ax.legend(loc="best")

    ax.set_xlabel("overlap m")
    ax.set_ylabel("median V^2 + D^2")
    ax.set_ylim(0.8, 1.1)
    ax.set_title("Ablations — effect on V^2 + D^2 (median)")
    fig.savefig(outp, dpi=160, bbox_inches="tight")
    plt.close(fig)

def main(summary_csv: str, ablation_csv: str | None, out_dir: str):
    """Render all Simulation A figures into *out_dir*.

    Creates *out_dir* (including parents) if needed and reads *summary_csv*.
    It then writes ``simA_v_vs_d.png`` and ``simA_V_D_vs_m.png``. Next it
    reads *ablation_csv* if one is given (a falsy value means no ablations).
    Finally it writes ``simA_ablation.png``.
    """
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    summary = pd.read_csv(summary_csv)
    fig_v_vs_d(summary, target_dir / "simA_v_vs_d.png")
    fig_vd_vs_m(summary, target_dir / "simA_V_D_vs_m.png")

    # Read after the first two figures, so a bad ablation path still
    # leaves those figures on disk.
    ablations = None
    if ablation_csv:
        ablations = pd.read_csv(ablation_csv)
    fig_ablation(summary, ablations, target_dir / "simA_ablation.png")

if __name__ == "__main__":
    # Command-line entry point: --summary and --output_dir are mandatory,
    # --ablation is optional and defaults to None (no ablation figure lines).
    parser = argparse.ArgumentParser()
    parser.add_argument("--summary", required=True)
    parser.add_argument("--ablation", default=None)
    parser.add_argument("--output_dir", required=True)
    cli = parser.parse_args()
    main(
        summary_csv=cli.summary,
        ablation_csv=cli.ablation,
        out_dir=cli.output_dir,
    )
